// Copyright 2022 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef XLS_JIT_FUNCTION_JIT_H_
#define XLS_JIT_FUNCTION_JIT_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/include/llvm/IR/Function.h"
#include "xls/common/status/ret_check.h"
#include "xls/interpreter/evaluator_options.h"
#include "xls/ir/events.h"
#include "xls/ir/function.h"
#include "xls/ir/package.h"
#include "xls/ir/type.h"
#include "xls/ir/value.h"
#include "xls/jit/aot_entrypoint.pb.h"
#include "xls/jit/function_base_jit.h"
#include "xls/jit/jit_buffer.h"
#include "xls/jit/jit_callbacks.h"
#include "xls/jit/jit_evaluator_options.h"
#include "xls/jit/jit_runtime.h"
#include "xls/jit/observer.h"
#include "xls/jit/orc_jit.h"

namespace xls {

// This class provides a facility to execute XLS functions (on the host) by
// converting it to LLVM IR, compiling it, and finally executing it. Not
// thread-safe due to sharing of result and temporary buffers between
// invocations of Run.
class FunctionJit {
 public:
  // Returns an object containing a host-compiled version of the specified XLS
  // function.
  static absl::StatusOr<std::unique_ptr<FunctionJit>> Create(
      Function* xls_function,
      const EvaluatorOptions& options = EvaluatorOptions(),
      const JitEvaluatorOptions& jit_options = JitEvaluatorOptions());

  // Returns an object containing an AOT-compiled version of the specified XLS
  // function.
  static absl::StatusOr<std::unique_ptr<FunctionJit>> CreateFromAot(
      const AotEntrypointProto& entrypoint, std::string_view data_layout,
      JitFunctionType function_unpacked,
      std::optional<JitFunctionType> function_packed = std::nullopt,
      const EvaluatorOptions& options = EvaluatorOptions());

  // Returns the bytes of an object file containing the compiled XLS function.
  static absl::StatusOr<JitObjectCode> CreateObjectCode(
      Function* xls_function,
      const EvaluatorOptions& options = EvaluatorOptions(),
      const JitEvaluatorOptions& jit_options = JitEvaluatorOptions());

  // Executes the compiled function with the specified arguments.
  absl::StatusOr<InterpreterResult<Value>> Run(absl::Span<const Value> args);

  // As above, buth with arguments as key-value pairs.
  absl::StatusOr<InterpreterResult<Value>> Run(
      const absl::flat_hash_map<std::string, Value>& kwargs);

  // Executes the compiled function with the arguments and results specified as
  // "views" - flat buffers onto which structures layouts can be applied (see
  // value_view.h).
  //
  // Argument packing and unpacking into and out of LLVM-space can consume a
  // surprisingly large amount of execution time. Deferring such transformations
  // (and applying views) can eliminate this overhead and still give access tor
  // result data. Users needing less performance can still use the
  // Value-returning methods above for code simplicity.
  // TODO(https://github.com/google/xls/issues/506): 2021-10-13 Figure out
  // if we want a way to return events in the view and packed view interfaces
  // (or if their performance-focused execution means events are unimportant).
  template <bool kForceZeroCopy = false>
  absl::Status RunWithViews(absl::Span<uint8_t* const> args,
                            absl::Span<uint8_t> result_buffer,
                            InterpreterEvents* events);

  // Similar to RunWithViews(), except the arguments here are _packed_views_ -
  // views whose data elements are tightly packed, with no padding bits or bytes
  // between them. The function return value is specified as the last arg - its
  // storage must ALSO be pre-allocated before this call.
  //
  // Example (for a binary float32 operation):
  // float RunFloat(float a_f, float b_f);
  //  using PF32 = PackedTupleView<PackedBitsView<23>, ..., PackedBitsView<1>>;
  //  PF32 a(&a_f);
  //  PF32 b(&b_f);
  //  float x_f;
  //  PF32 x(&x_f);
  //  jit->RunWithPackedViews(a, b, x);
  //  return x_f;
  //
  // For most users, the autogenerated DSLX headers should be used as the JIT -
  // and especially the packed-view-using-call - interface; there are some
  // sharp edges here!
  template <typename... ArgsT>
  absl::Status RunWithPackedViews(ArgsT... args) {
    XLS_RET_CHECK(jitted_function_base_.HasPackedFunction());
    const uint8_t* arg_buffers[sizeof...(ArgsT)];
    uint8_t* result_buffer;
    // Walk the type tree to get each arg's data buffer into our view/arg list.
    PackArgBuffers(arg_buffers, &result_buffer, args...);

    InterpreterEvents events;
    uint8_t* output_buffers[1] = {result_buffer};
    jitted_function_base_.RunPackedJittedFunction(
        arg_buffers, output_buffers, temp_buffer_.get_base_pointer(), &events,
        /*instance_context=*/&callbacks_, runtime(), /*continuation_point=*/0);

    return InterpreterEventsToStatus(events);
  }

  // Same as RunWithPackedViews but expects a View rather than a PackedView.
  template <typename... ArgsT>
  absl::Status RunWithUnpackedViews(ArgsT... args) {
    return RunWithUnpackedViewsCommon</*kForceZeroCopy=*/false, ArgsT...>(
        args...);
  }

  // Same as RunWithPackedViews but expects a View rather than a PackedView.
  // Guaranteed to run without copying the arguments first. The arguments must
  // be aligned correctly.
  // NOTE: Alignment is determined by LLVM and might change with little warning.
  // TODO(allight): 2023-12-6 We need to make this more usable safely.
  template <typename... ArgsT>
  absl::Status RunWithUnpackedViewsZeroCopy(ArgsT... args) {
    return RunWithUnpackedViewsCommon</*kForceZeroCopy=*/true, ArgsT...>(
        args...);
  }

  const JittedFunctionBase& jitted_function_base() const {
    return jitted_function_base_;
  }

  // Gets the size of the compiled function's arguments (or return value) in the
  // native LLVM data layout (not the packed layout).
  int64_t GetArgTypeSize(int arg_index) const {
    return jitted_function_base_.GetInputBufferMetadata()[arg_index].size;
  }
  int64_t GetArgTypeAlignment(int arg_index) const {
    return jitted_function_base_.GetInputBufferMetadata()[arg_index]
        .abi_alignment;
  }
  int64_t GetReturnTypeSize() const {
    return jitted_function_base_.GetOutputBufferMetadata()[0].size;
  }
  int64_t GetReturnTypeAlignment() const {
    return jitted_function_base_.GetOutputBufferMetadata()[0].abi_alignment;
  }

  // Gets the size of the compiled function's arguments (or return value) in the
  // packed layout.
  int64_t GetPackedArgTypeSize(int arg_index) const {
    return jitted_function_base_.GetInputBufferMetadata()[arg_index]
        .packed_size;
  }
  int64_t GetPackedReturnTypeSize() const {
    return jitted_function_base_.GetOutputBufferMetadata()[0].packed_size;
  }

  // Returns the size of the temporary buffer which must be passed to the jitted
  // function. The buffer is used to hold temporary node values inside the
  // jitted function.
  int64_t GetTempBufferSize() const {
    return jitted_function_base_.temp_buffer_size();
  }

  int64_t GetTempBufferAlignment() const {
    return jitted_function_base_.temp_buffer_alignment();
  }

  // Returns the name of the jitted function.
  std::string_view GetJittedFunctionName() const {
    return jitted_function_base_.function_name();
  }

  JitRuntime* runtime() const { return jit_runtime_.get(); }

  RuntimeObserver* CurrentRuntimeObserver() const {
    return callbacks_.observer;
  }

  void ClearRuntimeObserver() { callbacks_.observer = nullptr; }
  // Set a callback to get notified on each node's evaluation.
  absl::Status SetRuntimeObserver(RuntimeObserver* observer) {
    if (!has_observer_callbacks_) {
      return absl::UnimplementedError("Observer callbacks not supported.");
    }
    callbacks_.observer = observer;
    return absl::OkStatus();
  }
  bool SupportsObservers() const { return has_observer_callbacks_; }

 private:
  struct InterfaceMetadata {
    std::string name;

    // Package only used as owner of types.
    std::unique_ptr<Package> package;
    std::vector<std::string> param_names;
    std::vector<Type*> param_types;
    Type* return_type;

    static absl::StatusOr<InterfaceMetadata> CreateFromFunction(
        Function* function);
    static absl::StatusOr<InterfaceMetadata> CreateFromAotEntrypoint(
        const AotEntrypointProto& entrypoint);

    int64_t ParamCount() const { return param_names.size(); }
  };

  FunctionJit(InterfaceMetadata&& metadata, std::unique_ptr<OrcJit>&& orc_jit,
              JittedFunctionBase&& jitted_function_base,
              bool has_observer_callbacks,
              std::unique_ptr<JitRuntime>&& runtime)
      : metadata_(std::move(metadata)),
        orc_jit_(std::move(orc_jit)),
        jitted_function_base_(std::move(jitted_function_base)),
        arg_buffers_(jitted_function_base_.CreateInputBuffer()),
        result_buffers_(jitted_function_base_.CreateOutputBuffer()),
        temp_buffer_(jitted_function_base_.CreateTempBuffer()),
        jit_runtime_(std::move(runtime)),
        has_observer_callbacks_(has_observer_callbacks) {}

  static absl::StatusOr<std::unique_ptr<FunctionJit>> CreateInternal(
      Function* xls_function, const EvaluatorOptions& options,
      const JitEvaluatorOptions& jit_options);

  template <bool kForceZeroCopy, typename... ArgsT>
  absl::Status RunWithUnpackedViewsCommon(ArgsT... args) {
    const uint8_t* arg_buffers[sizeof...(ArgsT)];
    uint8_t* result_buffer;

    // Walk the type tree to get each arg's data buffer into our view/arg list.
    PackArgBuffers(arg_buffers, &result_buffer, args...);

    InterpreterEvents events;
    InvokeUnalignedJitFunction<kForceZeroCopy>(arg_buffers, result_buffer,
                                               &events);
    return InterpreterEventsToStatus(events);
  }

  // Builds a function which wraps the natively compiled XLS function `callee`
  // (as built by xls::BuildFunction) with another function which accepts the
  // input arguments as an array of pointers to buffers and the output as a
  // pointer to a buffer. The input/output values are in the native LLVM data
  // layout. The function signature is:
  //
  //   void f(uint8_t*[] inputs, uint8_t* output,
  //          void* events, void* instance_context, void* jit_runtime)
  //
  // `inputs` is an array containing a pointer for each input argument. The
  // pointer points to a buffer containing the respective argument in the native
  // LLVM data layout.
  //
  // `outputs` points to an empty buffer appropriately sized to accept the
  // result in the native LLVM data layout.
  absl::StatusOr<llvm::Function*> BuildWrapper(llvm::Function* callee);

  // As BuildWrapper but the inputs and outputs are taken/returned in a packed
  // representation.
  absl::StatusOr<llvm::Function*> BuildPackedWrapper(llvm::Function* callee);

  // Simple templates to walk down the arg tree and populate the corresponding
  // arg/buffer pointer.
  template <typename FrontT, typename... RestT>
  void PackArgBuffers(const uint8_t** arg_buffers, uint8_t** result_buffer,
                      FrontT front, RestT... rest) {
    arg_buffers[0] = front.buffer();
    PackArgBuffers(&arg_buffers[1], result_buffer, rest...);
  }

  // Base case for the above recursive template.
  template <typename LastT>
  void PackArgBuffers(const uint8_t** arg_buffers, uint8_t** result_buffer,
                      LastT front) {
    *result_buffer = front.mutable_buffer();
  }

  // Invokes the jitted function with the given argument and outputs.
  template <bool kForceZeroCopy = false>
  void InvokeUnalignedJitFunction(absl::Span<const uint8_t* const> arg_buffers,
                                  uint8_t* output_buffer,
                                  InterpreterEvents* events);

  InterfaceMetadata metadata_;

  std::unique_ptr<OrcJit> orc_jit_;

  JittedFunctionBase jitted_function_base_;

  // Pre-allocated & aligned storage for a set of arguments. Not thread safe.
  std::unique_ptr<JitArgumentSetOwnedBuffer> arg_buffers_;
  // Pre-allocated & aligned storage for a result. Not thread safe.
  std::unique_ptr<JitArgumentSetOwnedBuffer> result_buffers_;
  // Pre-allocated & aligned storage for required temporary storage. NB Not
  // thread safe.
  JitTempBuffer temp_buffer_;

  // Context callbacks.
  InstanceContext callbacks_ = InstanceContext::CreateForFunc();

  std::unique_ptr<JitRuntime> jit_runtime_;

  // Are callbacks for node-values compiled in.
  bool has_observer_callbacks_;
};

}  // namespace xls

#endif  // XLS_JIT_FUNCTION_JIT_H_
